import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl

# Set better aesthetics
plt.style.use('seaborn-v0_8-whitegrid')
# Using default font family
mpl.rcParams['font.size'] = 14
mpl.rcParams['axes.linewidth'] = 1.5
mpl.rcParams['axes.edgecolor'] = '#333333'

# Read data from the TSV file
import pandas as pd

# Load the data
data = pd.read_csv('data.tsv', sep='\t')

# Extract steps and all columns
steps = data['Step'].values

# Define column names and pretty labels
columns = ['64ept-8tpk-2itr', '64ept-8tpk-1itr', '64ept-16tpk-1itr', '64ept-24tpk-1itr']
labels = ['CoE-2(8/64)', 'MoE(8/64)', 'MoE(16/64)', 'MoE(24/64)']

# Plot with better aesthetics
plt.figure(figsize=(12*0.7, 7*0.7))

# Find the index where steps >= 100
start_idx = np.where(steps >= 100)[0]
start_idx = 0 if len(start_idx) == 0 else start_idx[0]

# Define colors and markers
colors = ['#ff7f0e', '#1f77b4', '#2ca02c', '#d62728']
markers = ['s', 'o', '^', 'D']
linestyles = ['--', '-', '-.', ':']

# Create plot with log scales for all columns
for i, (column, label) in enumerate(zip(columns, labels)):
    plt.loglog(steps[start_idx:], data[column].values[start_idx:], 
               label=label, 
               linestyle=linestyles[i], 
               linewidth=2.5,
               marker=markers[i],
               markersize=5,
               markevery=5,
               color=colors[i])

# Add better labels and title
plt.xlabel('Steps (log scale)', fontweight='bold', fontsize=16)
plt.ylabel('Validation Loss (log scale)', fontweight='bold', fontsize=16)
plt.title('Compute Scaling: #Iteration (CoE) > #TopK (MoE)', 
          fontweight='bold', fontsize=18, pad=20)

# Improve legend
plt.legend(frameon=True, fontsize=12, framealpha=0.7, 
           edgecolor='#333333', loc='upper right')

# Add minor grid lines
plt.grid(True, which='minor', linestyle=':', alpha=0.4)
plt.grid(True, which='major', linestyle='-', alpha=0.5)

# No annotation

# Adjust axis limits to focus on the important part
plt.xlim(370, 1000)
plt.ylim(1, 2)

# Add a subtle background color
plt.gca().set_facecolor('#f8f9fa')

# Save with higher DPI for better quality
plt.tight_layout()
plt.savefig('plot_scal.png', dpi=300, bbox_inches='tight')
plt.show()